Biomarker & Data Insights
Analysis Report

This demo provides an overview of applying Random Survival Forests to simulated data. We split the demo into the following steps:

  1. Data preprocessing
  2. Hyperparameter tuning
  3. Performance evaluation
  4. VIMP (Variable Importance)
  5. PDPs (Partial Dependence Plots)

Dataset simulation

We simulate covariates and a right-censored survival outcome using a proportional hazards model with a time-constant baseline hazard:

\(Surv(time,status) \sim HCT+BPSYS+trt+trt:BMI\)

Where:

  • Main effects
      • trt: Treatment
      • HCT: Hematocrit
      • BPSYS: Systolic blood pressure
  • Interaction with treatment
      • BMI: Body mass index
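
The exact simulation code is not shown in this report; under a time-constant baseline hazard, data of this shape could be generated by inverse-transform sampling, as in the base-R sketch below. All coefficients, the baseline hazard, and the censoring time are purely illustrative.

```r
set.seed(42)
n <- 300
trt   <- rbinom(n, 1, 0.5)              # treatment indicator
HCT   <- rnorm(n, mean = 41, sd = 5)    # hematocrit
BPSYS <- rnorm(n, mean = 118, sd = 14)  # systolic blood pressure
BMI   <- rnorm(n, mean = 28, sd = 5)    # body mass index

# linear predictor: main effects plus a treatment-BMI interaction
# (coefficients are made up for illustration)
lp <- 0.02 * HCT + 0.01 * BPSYS - 0.5 * trt + 0.03 * trt * BMI
lambda0 <- 2e-4                          # time-constant baseline hazard

# with a constant baseline hazard, event times are exponential
# with subject-specific rate lambda0 * exp(lp)
event_time <- rexp(n, rate = lambda0 * exp(lp))
time   <- pmin(event_time, 200)          # administrative censoring at t = 200
status <- as.integer(event_time <= 200)  # 1 = event, 0 = censored
head(data.frame(.time = time, .status = status, trt, HCT, BPSYS, BMI))
```

With these illustrative values, roughly a quarter of the subjects experience the event before the censoring time, similar to the proportions seen in the data summary below.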

The dataset is deliberately small so that the knitting runtime stays short. Still, some chunks take a while to run; we set these chunks to eval=FALSE and saved their outcomes in .qs files, which are read back in subsequent chunks. To run these chunks yourself, set eval to TRUE.

# load the tidyverse
library(tidyverse)
#load utility packages
library(kableExtra)
library(here)
library(tictoc)
library(stringr)
library(qs)
library(skimr)
library(ggplot2)

Data import

data<-qs::qread(here::here("Data","Demo_data_tte.qs"))

Exploring the data

data%>% skimr::skim()
Data summary
Name Piped data
Number of rows 289
Number of columns 38
_______________________
Column type frequency:
factor 8
numeric 30
________________________
Group variables None

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
.trt 0 1.00 FALSE 2 TRT: 153, PLC: 136
SEX 0 1.00 FALSE 2 M: 148, F: 141
RACE 20 0.93 FALSE 3 WHI: 199, ASI: 40, BLA: 30
atrial_fibrillation 0 1.00 FALSE 2 yes: 176, no: 113
myocardial_infarction 0 1.00 FALSE 2 no: 219, yes: 70
coronary_artery_disease 0 1.00 FALSE 2 no: 243, yes: 46
ventricular_tachycardia 0 1.00 FALSE 2 no: 278, yes: 11
angina_pectoris 0 1.00 FALSE 2 no: 279, yes: 10

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
.id 0 1.00 10159.92 93.55 10001.00 10080.00 10159.00 10241.00 10320.00 ▇▇▇▇▇
.time 0 1.00 126.96 85.66 0.00 24.54 200.00 200.00 200.00 ▅▁▁▁▇
.status 0 1.00 0.28 0.45 0.00 0.00 0.00 1.00 1.00 ▇▁▁▁▃
AGE 0 1.00 67.31 13.52 45.00 57.00 67.00 80.00 92.00 ▇▇▇▆▆
CALCIUM 3 0.99 9.50 0.43 8.17 9.21 9.51 9.80 10.74 ▁▅▇▆▁
CREAT 4 0.99 1.23 0.33 0.65 1.00 1.19 1.40 2.54 ▅▇▃▁▁
GGT 2 0.99 59.41 70.68 1.81 21.46 38.30 70.65 775.85 ▇▁▁▁▁
HB 4 0.99 13.59 1.58 10.21 12.43 13.48 14.49 19.82 ▃▇▅▂▁
HCT 3 0.99 41.14 4.72 28.87 37.87 40.72 44.50 56.67 ▂▇▇▃▁
HDL 2 0.99 47.57 15.08 18.61 36.58 45.77 54.98 105.61 ▃▇▃▁▁
LDL 1 1.00 86.48 34.45 27.35 61.72 82.29 103.90 255.75 ▆▇▂▁▁
MAGNES 1 1.00 2.19 0.28 1.46 1.98 2.17 2.38 2.98 ▂▆▇▅▁
POTASS 3 0.99 4.37 0.45 3.26 4.05 4.37 4.66 5.76 ▂▆▇▂▁
SODIUM 3 0.99 138.43 2.77 129.19 136.53 138.41 140.32 147.43 ▁▃▇▃▁
URICAC 5 0.98 7.78 2.28 3.43 6.16 7.42 9.20 20.94 ▆▇▂▁▁
BMI 2 0.99 27.84 4.82 16.70 24.20 28.00 31.05 42.10 ▂▇▇▃▁
BPDIA 4 0.99 71.55 10.18 43.00 64.00 71.00 78.00 100.00 ▁▅▇▅▁
BPSYS 2 0.99 117.78 14.45 83.00 107.50 118.00 128.00 162.00 ▂▆▇▃▁
HR 3 0.99 68.96 11.01 39.00 62.00 69.00 77.00 98.00 ▁▅▇▅▁
WEIGHT 2 0.99 82.23 17.51 36.10 68.75 81.50 96.10 126.20 ▁▇▇▇▂
noise1 0 1.00 -0.03 0.97 -2.76 -0.69 0.03 0.69 2.17 ▁▃▇▆▂
noise2 0 1.00 0.06 1.01 -2.75 -0.53 0.03 0.81 3.20 ▁▅▇▅▁
noise3 0 1.00 0.00 1.01 -2.68 -0.64 -0.02 0.66 3.54 ▂▆▇▂▁
noise4 0 1.00 0.02 0.98 -3.70 -0.63 0.05 0.63 2.81 ▁▂▇▇▁
noise5 0 1.00 0.03 1.00 -3.02 -0.63 0.07 0.71 2.86 ▁▃▇▅▁
noise6 0 1.00 -0.05 0.99 -3.04 -0.66 -0.07 0.65 3.20 ▁▅▇▃▁
noise7 0 1.00 -0.04 1.03 -3.63 -0.76 0.00 0.59 2.87 ▁▃▇▇▁
noise8 0 1.00 -0.07 1.00 -2.53 -0.74 -0.06 0.57 2.82 ▂▆▇▃▁
noise9 0 1.00 0.05 1.00 -2.91 -0.56 0.01 0.70 2.59 ▁▃▇▆▂
noise10 0 1.00 0.05 1.03 -2.81 -0.61 0.13 0.75 3.06 ▁▅▇▅▁

Data initial split

We can start by loading the tidymodels metapackage and splitting our data into training and testing sets

library(tidymodels)

set.seed(123)
#create a single binary split of the data into a training set and testing set
data_split <- rsample::initial_split(data, strata = .status)
#extract  the resulting data 
data_train <- rsample::training(data_split)
data_test <- rsample::testing(data_split)

Pre-processing data

We pre-process the training data and apply exactly the same steps to the test data:

  • Imputation of missing data using k-nearest neighbours
  • Normalization
  • Removal of near-zero-variance features
  • Correlation filter with r = 0.9
# textrecipes contains extra steps for the recipes package for preprocessing text data
library(textrecipes)
# make a recipe ####
tte_recipe <-
  recipes::recipe(formula = .time + .status ~ ., data = data_train) %>%
  recipes::update_role(.id, new_role = "id") %>%
  recipes::update_role(c(.time, .status), new_role = "outcome") %>%
  recipes::step_impute_knn(recipes::all_predictors(), -.trt) %>%
  recipes::step_naomit(recipes::all_predictors()) %>%
  recipes::step_nzv(recipes::all_predictors(),
                    freq_cut = 95 / 5,
                    unique_cut = 10) %>%
  recipes::step_normalize(recipes::all_numeric_predictors()) %>%
  recipes::step_corr(recipes::all_numeric_predictors(), threshold = 0.9)


#prepare new data####
prep_tte_recipe <- tte_recipe %>%
  recipes::prep()

prep_data_test <-
  recipes::bake(object = prep_tte_recipe, new_data = data_test)
prep_data_train <- recipes::juice(prep_tte_recipe)


# inspect data ####
data_prep <- prep_data_train %>%
  bind_rows(prep_data_test)
data_prep  %>%
  skimr::skim()
Data summary
Name Piped data
Number of rows 289
Number of columns 36
_______________________
Column type frequency:
factor 7
numeric 29
________________________
Group variables None

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
.trt 0 1 FALSE 2 TRT: 153, PLC: 136
SEX 0 1 FALSE 2 M: 148, F: 141
RACE 0 1 FALSE 3 WHI: 217, ASI: 41, BLA: 31
atrial_fibrillation 0 1 FALSE 2 yes: 176, no: 113
myocardial_infarction 0 1 FALSE 2 no: 219, yes: 70
coronary_artery_disease 0 1 FALSE 2 no: 243, yes: 46
ventricular_tachycardia 0 1 FALSE 2 no: 278, yes: 11

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
.id 0 1 10159.92 93.55 10001.00 10080.00 10159.00 10241.00 10320.00 ▇▇▇▇▇
AGE 0 1 -0.03 1.02 -1.71 -0.81 -0.06 0.92 1.82 ▇▇▇▆▆
CALCIUM 0 1 0.00 1.01 -3.16 -0.66 0.04 0.71 2.96 ▁▅▇▆▁
CREAT 0 1 0.03 1.06 -1.88 -0.72 -0.10 0.58 4.31 ▅▇▃▁▁
GGT 0 1 0.05 1.04 -0.80 -0.51 -0.26 0.21 10.63 ▇▁▁▁▁
HCT 0 1 0.00 0.96 -2.50 -0.67 -0.09 0.65 3.16 ▂▇▇▃▁
HDL 0 1 -0.07 0.99 -1.98 -0.78 -0.19 0.43 3.77 ▃▇▃▁▁
LDL 0 1 0.00 1.04 -1.79 -0.74 -0.12 0.53 5.13 ▆▇▂▁▁
MAGNES 0 1 0.03 0.99 -2.49 -0.69 -0.02 0.71 2.79 ▂▆▇▅▁
POTASS 0 1 -0.05 0.95 -2.38 -0.72 -0.02 0.57 2.88 ▂▆▇▂▁
SODIUM 0 1 -0.01 1.04 -3.51 -0.72 -0.02 0.70 3.39 ▁▃▇▃▁
URICAC 0 1 0.01 0.99 -1.89 -0.69 -0.14 0.62 5.74 ▆▇▂▁▁
BMI 0 1 0.02 0.99 -2.28 -0.73 0.03 0.67 2.97 ▂▇▇▃▁
BPDIA 0 1 0.06 1.02 -2.81 -0.70 0.00 0.70 2.91 ▁▅▇▅▁
BPSYS 0 1 0.07 1.02 -2.39 -0.62 0.08 0.79 3.19 ▂▆▇▃▁
HR 0 1 0.05 1.03 -2.76 -0.60 0.06 0.82 2.79 ▁▅▇▅▁
WEIGHT 0 1 0.03 0.99 -2.59 -0.72 -0.01 0.81 2.51 ▁▇▇▇▂
noise1 0 1 -0.02 0.99 -2.81 -0.70 0.04 0.71 2.22 ▁▃▇▆▂
noise2 0 1 -0.03 1.03 -2.89 -0.63 -0.06 0.73 3.16 ▁▅▇▅▁
noise3 0 1 0.05 1.01 -2.62 -0.58 0.04 0.71 3.59 ▂▆▇▂▁
noise4 0 1 0.02 1.01 -3.80 -0.65 0.05 0.65 2.89 ▁▂▇▇▁
noise5 0 1 -0.04 1.01 -3.12 -0.71 0.00 0.65 2.82 ▁▃▇▅▁
noise6 0 1 -0.04 1.00 -3.07 -0.66 -0.07 0.66 3.24 ▁▅▇▃▁
noise7 0 1 -0.06 0.99 -3.51 -0.76 -0.03 0.55 2.73 ▁▃▇▇▁
noise8 0 1 0.05 1.00 -2.41 -0.62 0.05 0.68 2.93 ▂▆▇▃▁
noise9 0 1 -0.02 1.01 -3.01 -0.63 -0.06 0.64 2.54 ▁▃▇▆▂
noise10 0 1 -0.01 1.02 -2.84 -0.66 0.07 0.68 2.97 ▁▅▇▅▁
.time 0 1 126.96 85.66 0.00 24.54 200.00 200.00 200.00 ▅▁▁▁▇
.status 0 1 0.28 0.45 0.00 0.00 0.00 1.00 1.00 ▇▁▁▁▃

Setup model

Construct the survival task

library(mlr3learners)# extends mlr3 with popular learners; needed for ranger
library(mlr3proba)# supports survival analysis
library(mlr3)# core framework



# construction of Survival task ####

# First we put the data into a memory-efficient data.table backend
# create instance
data_use <-
  mlr3::DataBackendDataTable$new(
    data = prep_data_train %>%
      dplyr::mutate(.id = as.integer(.id)) %>% data.table::as.data.table(),
    primary_key = ".id"
  )
# Specify the survival task, create new instance
surv_task <- mlr3proba::TaskSurv$new(
  id = "surv_example",
  backend = data_use,
  time = ".time",
  event = ".status",
  type = c("right")
)

Kaplan-Meier curve

# Explore Kaplan-Meier curve
mlr3viz::autoplot(surv_task)

Build the learner (survival RF from ranger package)

# build the learner
ranger_lrn <- mlr3::lrn(
  "surv.ranger",
  respect.unordered.factors = "order",
  verbose = FALSE,
  importance = "permutation"
) #Variable importance mode
# Inspect parameters

ranger_lrn$param_set
## <ParamSet>
##                               id    class lower upper nlevels        default
##  1:                    num.trees ParamInt     1   Inf     Inf            500
##  2:                         mtry ParamInt     1   Inf     Inf <NoDefault[3]>
##  3:                   importance ParamFct    NA    NA       4 <NoDefault[3]>
##  4:                 write.forest ParamLgl    NA    NA       2           TRUE
##  5:                min.node.size ParamInt     1   Inf     Inf              5
##  6:                      replace ParamLgl    NA    NA       2           TRUE
##  7:              sample.fraction ParamDbl     0     1     Inf <NoDefault[3]>
##  8:                    splitrule ParamFct    NA    NA       4        logrank
##  9:            num.random.splits ParamInt     1   Inf     Inf              1
## 10:                    max.depth ParamInt  -Inf   Inf     Inf               
## 11:                        alpha ParamDbl  -Inf   Inf     Inf            0.5
## 12:                      minprop ParamDbl  -Inf   Inf     Inf            0.1
## 13:        regularization.factor ParamUty    NA    NA     Inf              1
## 14:      regularization.usedepth ParamLgl    NA    NA       2          FALSE
## 15:                         seed ParamInt  -Inf   Inf     Inf               
## 16:         split.select.weights ParamDbl     0     1     Inf <NoDefault[3]>
## 17:       always.split.variables ParamUty    NA    NA     Inf <NoDefault[3]>
## 18:    respect.unordered.factors ParamFct    NA    NA       3         ignore
## 19: scale.permutation.importance ParamLgl    NA    NA       2          FALSE
## 20:                   keep.inbag ParamLgl    NA    NA       2          FALSE
## 21:                      holdout ParamLgl    NA    NA       2          FALSE
## 22:                  num.threads ParamInt     1   Inf     Inf              1
## 23:                  save.memory ParamLgl    NA    NA       2          FALSE
## 24:                      verbose ParamLgl    NA    NA       2           TRUE
## 25:                    oob.error ParamLgl    NA    NA       2           TRUE
##                               id    class lower upper nlevels        default
##           value
##  1:            
##  2:            
##  3: permutation
##  4:            
##  5:            
##  6:            
##  7:            
##  8:            
##  9:            
## 10:            
## 11:            
## 12:            
## 13:            
## 14:            
## 15:            
## 16:            
## 17:            
## 18:       order
## 19:            
## 20:            
## 21:            
## 22:           1
## 23:            
## 24:       FALSE
## 25:            
##           value

Tuning

Now it’s time to tune!

We will tune the following random forest parameters:

  • num.trees, the number of trees
  • mtry, the number of variables to possibly split at in each node
  • min.node.size, the minimal terminal node size

We will use the mlr3 ecosystem to build a survival random forest and tune the hyperparameters.

Settings

Set the parameter search space

library(paradox)
search_space <- paradox::ps(
  num.trees = paradox::p_int(lower = 500, upper = 2000),
  mtry = paradox::p_int(
    lower = floor(length(surv_task$col_roles$feature) * 0.1),
    upper = floor(length(surv_task$col_roles$feature) * 0.9)
  ),
  min.node.size = paradox::p_int(lower = 1, upper = 40)
)
search_space
## <ParamSet>
##               id    class lower upper nlevels        default value
## 1:     num.trees ParamInt   500  2000    1501 <NoDefault[3]>      
## 2:          mtry ParamInt     3    29      27 <NoDefault[3]>      
## 3: min.node.size ParamInt     1    40      40 <NoDefault[3]>

Select resampling strategy and performance measure

We need to specify how to evaluate the performance of a trained model. For this, we need to choose a resampling strategy and a performance measure. Here we choose cross-validation and the C-index.
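
As a quick intuition for the C-index (Harrell's concordance): it is the fraction of comparable subject pairs in which the subject with the higher risk score also has the shorter observed event time. A toy base-R computation (data and risk scores are made up) could look like this:

```r
# toy data: times, event indicators (0 = censored), model risk scores
toy_time   <- c(2, 5, 7, 9)
toy_status <- c(1, 1, 0, 1)
toy_risk   <- c(0.9, 0.7, 0.4, 0.1)

cindex <- function(time, status, risk) {
  conc <- comp <- 0
  n <- length(time)
  for (i in seq_len(n)) for (j in seq_len(n)) {
    # a pair is comparable if the earlier time is an observed event
    if (time[i] < time[j] && status[i] == 1) {
      comp <- comp + 1
      if (risk[i] > risk[j]) conc <- conc + 1
      if (risk[i] == risk[j]) conc <- conc + 0.5
    }
  }
  conc / comp
}
cindex(toy_time, toy_status, toy_risk)  # 1: risk ordering matches event ordering
```

In practice, mlr3's surv.cindex measure performs this computation (with appropriate handling of ties and censoring) on the model's predicted risks.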

library(mlr3tuning)
#choose strategy and measure ####
#3-fold cross validation
hout <- mlr3::rsmp("cv", folds = 3)
measure <- mlr3::msr("surv.cindex")
#Terminator that stops after a number of evaluations
evals5 = mlr3tuning::trm("evals", n_evals = 5)

Generate tuning instance

#generate tuning instance, from task, learner, search space, resampling method and measure
instance <- mlr3tuning::TuningInstanceSingleCrit$new(
  task = surv_task,
  learner = ranger_lrn,
  resampling = hout,
  measure = measure,
  search_space = search_space,
  terminator = evals5
)

Initiate tuning

Note 1: the actual_tuning chunk is set to eval=FALSE. To run it, set eval=TRUE; to save the outcome, uncomment the relevant line (see comments in the code).

Note 2: we use parallelization, even though the dataset is small, for demonstration purposes.

#packages needed for parallelization
library(doFuture)
library(doRNG)
library(foreach)
tictoc::tic()
# enable parallel processing
doFuture::registerDoFuture()
future::plan(future::multisession, workers =  availableCores() - 1)

# specify seed
doRNG::registerDoRNG(seed = 123)


# define the tuner; random search is assumed here, as the tuner object was not defined above
tuner <- mlr3tuning::tnr("random_search")
tuner$optimize(instance)
# disable parallel backend
foreach::registerDoSEQ()

tictoc::toc()

#Uncomment next line to save the outcome
# qs::qsave(instance, here::here("Data", "htune_demo.qs"))

Evaluate tuning performance

How did the evaluated hyperparameter combinations do?

instance <- qs::qread(here::here("Data", "htune_demo.qs"))
hyparams <- instance$search_space$ids()
perf_data <- instance$archive$data

perf_data %>%
  select(num.trees, mtry,   min.node.size,  surv.harrell_c) %>%
  arrange(desc(surv.harrell_c)) %>%
  mutate(surv.harrell_c = surv.harrell_c %>% round(., digits = 4)) %>%
  kableExtra::kable(escape = FALSE) %>%
  kableExtra::kable_styling(
    bootstrap_options = "striped",
    full_width = FALSE,
    position = "left"
  ) %>%
  kableExtra::column_spec(1, bold = TRUE) %>%
  kableExtra::row_spec(0,
                       bold = TRUE,
                       background = "#00617F",
                       color = "white")
num.trees mtry min.node.size surv.harrell_c
1250 9 10 0.8597
1250 9 30 0.8459
1250 15 30 0.8386
875 28 30 0.8305
875 15 1 0.8264

Final model

Set tuned hyperparameters

Change hyperparameters to those selected in the tuning step

# adding best hyperparameters
ranger_lrn$param_set$values <- c(
  ranger_lrn$param_set$values,
  perf_data %>%
    select(num.trees, mtry, min.node.size,  surv.harrell_c) %>%
    arrange(desc(surv.harrell_c)) %>%
    select(-surv.harrell_c) %>%
    slice(1)
)

Train the final learner

set.seed(1234)
final_rf <- ranger_lrn$train(task = surv_task)

Performance

Testing

# predict the outcome with the test data
pred_test <- final_rf$predict_newdata(newdata = prep_data_test)
# Define the performance metrics
pred_measures <- suppressWarnings(mlr3::msrs("surv.cindex"))
# Estimate performance
test_performance <- pred_test$score(
  measures = pred_measures,
  task = surv_task,
  learner = final_rf,
  train_set = surv_task$row_ids
) %>%
  tibble::enframe(name = ".metric", value = ".estimate")
#print performance
test_performance %>%
  kableExtra::kable(escape = FALSE) %>%
  kableExtra::kable_styling(
    bootstrap_options = "striped",
    full_width = FALSE,
    position = "left"
  ) %>%
  kableExtra::column_spec(1, bold = TRUE) %>%
  kableExtra::row_spec(0,
                       bold = TRUE,
                       background = "#00617F",
                       color = "white") 
.metric .estimate
surv.harrell_c 0.927572

Training

# predict the outcome with the training data
pred_train <- final_rf$predict_newdata(newdata = prep_data_train)

# Estimate performance
train_performance <- pred_train$score(
  measures = pred_measures,
  task = surv_task,
  learner = final_rf,
  train_set = surv_task$row_ids
) %>%
  tibble::enframe(name = ".metric", value = ".estimate")
# Print performance
train_performance %>%
  kableExtra::kable(escape = FALSE) %>%
  kableExtra::kable_styling(
    bootstrap_options = "striped",
    full_width = FALSE,
    position = "left"
  ) %>%
  kableExtra::column_spec(1, bold = TRUE) %>%
  kableExtra::row_spec(0,
                       bold = TRUE,
                       background = "#00617F",
                       color = "white") 
.metric .estimate
surv.harrell_c 0.9606475

Variable Importance (VIMP)

Next, let's look at permutation feature importance for this model. For a ranger model, we need to set importance = "permutation" in the learner (as we did above) so that permutation-based importance values are computed during training.
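
As background, the permutation-importance idea is model-agnostic: permute one feature, re-predict, and measure how much the error grows. A base-R sketch on toy data (illustrative only, not the ranger implementation):

```r
set.seed(2)
df <- data.frame(x1 = rnorm(200), x2 = rnorm(200))
df$y <- 3 * df$x1 + rnorm(200)          # only x1 is informative
fit <- lm(y ~ x1 + x2, data = df)
base_mse <- mean((df$y - predict(fit, df))^2)

# importance = increase in MSE after permuting each feature
perm_imp <- sapply(c("x1", "x2"), function(v) {
  d <- df
  d[[v]] <- sample(d[[v]])              # break the feature-outcome link
  mean((df$y - predict(fit, d))^2) - base_mse
})
round(perm_imp, 2)                      # x1 >> x2
```

ranger does the analogous computation on the out-of-bag samples of each tree, using the survival-appropriate error measure.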

importance <- final_rf$importance() %>%
  as_tibble_col() %>%
  bind_cols(variables = final_rf$importance() %>% names()) %>%
  relocate(variables)

importance %>%
  DT::datatable(
    rownames = TRUE,
    filter = "top",
    selection = "single",
    extensions = c("Buttons"),
    options = list(
      lengthMenu = c(5, 10, 25, 50),
      pageLength = 5,
      scrollX = TRUE,
      dom = "lfrtBpi",
      buttons = list("excel")
    )
  )
# top 10
vi_nplot <- 10

#plot permutation importance
imp_fr_plt <- importance %>%
  dplyr::arrange(., desc(value)) %>%
  dplyr::slice(1:vi_nplot) %>%
  dplyr::mutate(Sign = as.factor(ifelse(value > 0, "positive", "negative")))

p <- imp_fr_plt %>%
  ggplot(aes(
    y = reorder(variables, value),
    x = value,
    fill = Sign
  )) +
  geom_col() +
  scale_fill_manual(values = c("#00659C", "#930A34")) +
  theme(
    legend.position = "none",
    axis.title.y = element_blank(),
    plot.subtitle = element_text(size = 11),
    plot.title.position = "plot",
    plot.margin = margin(r = 20)
  ) +
  labs(subtitle = paste0("The top ", vi_nplot, " Variables based on Permutation"))


#plot output
p

Explainable AI (PDPs)

We will show partial dependence plots, using all available data: train + test.
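
Before the RF-specific code, the partial-dependence idea can be sketched on a toy linear model (purely illustrative): for each grid value of a feature, set that feature to the grid value in every row and average the model's predictions.

```r
set.seed(1)
df <- data.frame(x = rnorm(100), z = rnorm(100))
df$y <- 2 * df$x + df$z + rnorm(100)
fit <- lm(y ~ x + z, data = df)

# PDP for x: overwrite x with each grid value and average the predictions
grid <- seq(-2, 2, length.out = 5)
pdp <- sapply(grid, function(g) mean(predict(fit, transform(df, x = g))))
round(pdp, 2)  # roughly linear in the grid with slope ~2, as expected
```

The code below applies the same recipe to the survival forest, except that the prediction is a survival probability at each of several time points, so we get one PDP curve per time point.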

HCT

library(tidyverse)
library(ranger)
rf_model <- final_rf$model
# define time points, represented by the rank/order in which they appear
time_points <-
  seq(
    from = 1,
    to = length(rf_model$unique.death.times),
    length.out = 10
  ) %>% round()

Note: the PDP predictions chunk is set to eval=FALSE. To run it, set eval=TRUE; to save the outcome, uncomment the relevant line (see comments in the code).

# define feature and grid (based on feature class)
feat <- "HCT"

feat_cat <- data_prep %>% dplyr::pull(!!feat) %>% class()
n_grid <- 50


if (feat_cat == "numeric") {
  feat_range <- data_prep %>%
    dplyr::pull(!!feat) %>%
    range()
  
  feat_grid <-
    seq(from = feat_range[1],
        to = feat_range[2],
        length.out = n_grid)
  
} else {
  feat_grid <- data_prep %>%
    dplyr::pull(!!feat) %>%
    levels()
}

# replace corresponding feature values with grid values
data_sets <-  purrr::map(feat_grid,
                         ~ data_prep %>% dplyr::mutate(dplyr::across(
                           tidyselect::all_of(feat),
                           .fns = function(y)
                             .x
                         )))


# calculate predictions for specified time points ####
preds <- map(
  data_sets,
  ~ ranger:::predict.ranger(rf_model, data = .x)$survival %>%
    tibble::as_tibble() %>%
    dplyr::select(all_of(time_points))
)
# calculate PDPs (average (survival probability) for each feature grid value per timepoint)
pdp_data <- purrr::map2(
  preds,
  feat_grid,
  ~ .x %>%
    apply(2, mean) %>%
    tibble::enframe(value = ".value", name = "time_id") %>%
    dplyr::mutate(time_id = stringr::str_replace(time_id, "V", "") %>% as.numeric()) %>%
    dplyr::mutate(feat_val1 = .y)
) %>%
  dplyr::bind_rows() %>%
  dplyr::rename({
    {
      feat
    }
  } := feat_val1)

#Uncomment next line to save the outcome
# qs::qsave(pdp_data, here::here("Data", "hct_pdp_preds.qs"))

Plot PDPs

feat <- "HCT"
pdp_data <- qs::qread(here::here("Data", "hct_pdp_preds.qs"))
feat_cat <- data_prep %>% dplyr::pull(!!feat) %>% class()

#create rug for HCT
hct_rug <- data_prep %>%
  dplyr::pull(!!feat)
#plot
p <- pdp_data %>%
  ggplot(aes(x = !!rlang::sym(feat), y = 1 - .value)) + # 1 minus for event probability, not survival prob
  {
    if (feat_cat == "numeric")
      geom_line()
    else
      geom_col()
  } +
  facet_wrap(
    ~ time_id,
    nrow = 2,
    labeller =  ggplot2::labeller(
      time_id  = function(s) {
        rf_model$unique.death.times[as.numeric(s)] %>% round(3)
      },
      # construct time labels within the function
      .default = ggplot2::label_value
    )
  ) +
  ylab("event probability")

p + ggplot2::geom_rug(
  data = hct_rug %>%
    tibble::enframe(),
  mapping = ggplot2::aes(x = value),
  inherit.aes = F,
  sides = "b",
  alpha = 1,
  col = "#B3B3B3"
)

.trt

Calculate PDPs

# define feature and grid (based on feature class)
feat <- ".trt"

feat_cat <- data_prep %>% dplyr::pull(!!feat) %>% class()
n_grid <- 50


if (feat_cat == "numeric") {
  feat_range <- data_prep %>%
    dplyr::pull(!!feat) %>%
    range()
  
  feat_grid <-
    seq(from = feat_range[1],
        to = feat_range[2],
        length.out = n_grid)
  
} else {
  feat_grid <- data_prep %>%
    dplyr::pull(!!feat) %>%
    levels()
}

# replace corresponding feature values with grid values
data_sets <-  purrr::map(feat_grid,
                         ~ data_prep %>% dplyr::mutate(dplyr::across(
                           tidyselect::all_of(feat),
                           .fns = function(y)
                             .x
                         )))

# calculate predictions for specified time points
preds <- map(
  data_sets,
  ~ ranger:::predict.ranger(rf_model, data = .x)$survival %>%
    tibble::as_tibble() %>%
    dplyr::select(all_of(time_points))
)
# calculate PDPs (average (survival probability) for each feature grid value per timepoint)
pdp_data <- purrr::map2(
  preds,
  feat_grid,
  ~ .x %>%
    apply(2, mean) %>%
    tibble::enframe(value = ".value", name = "time_id") %>%
    dplyr::mutate(time_id = stringr::str_replace(time_id, "V", "") %>% as.numeric()) %>%
    dplyr::mutate(feat_val1 = .y)
) %>%
  dplyr::bind_rows() %>%
  dplyr::rename({
    {
      feat
    }
  } := feat_val1)

Plot PDPs

p<-pdp_data %>%
  ggplot(aes(x = !!rlang::sym(feat), y = 1 - .value)) + # 1 minus for event probability, not survival prob
  { if (feat_cat == "numeric") geom_line() else geom_col()} +
  facet_wrap(~ time_id, 
             nrow = 2, 
             labeller =  ggplot2::labeller(time_id  = function(s) {rf_model$unique.death.times[as.numeric(s)] %>% round(3) }, # construct time labels within the function
                                           .default = ggplot2::label_value)) +
  ylab("event probability")
p

.trt and BMI: 2-D PDP

Calculate 2D PDPs for BMI and .trt. Note: the PDP predictions chunk is set to eval=FALSE. To run it, set eval=TRUE; to save the outcome, uncomment the relevant line (see comments in the code).

# define grid ####
feat <- c("BMI", ".trt")

n_grid <- 50

#create range for BMI
bmi_range <- data_prep %>%
  dplyr::pull(BMI) %>%
  range()

bmi_grid <-
  seq(from = bmi_range[1],
      to = bmi_range[2],
      length.out = n_grid)

#get trt levels
trt_grid <- data_prep %>%
  dplyr::pull(.trt) %>%
  levels()


# replace corresponding feature values with grid values
data_sets <-
  tidyr::expand_grid(BMI = bmi_grid, .trt = trt_grid, data_prep %>% select(-c(BMI, .trt)))

# calculate predictions for specified time points

preds <-
  ranger:::predict.ranger(rf_model, data = data_sets)$survival %>%
  tibble::as_tibble() %>%
  dplyr::select(all_of(time_points))

#merge predictions with data set and calculate 2D PDPs (average (survival probability) for each feature combination value per timepoint)
pdp_data  <- data_sets %>% bind_cols(preds) %>%
  filter(BMI %in% bmi_grid) %>%
  pivot_longer(
    c(colnames(preds)),
    names_to = "time_id",
    values_to = ".values"
  )  %>% mutate(time_id = stringr::str_replace(time_id, "V", "") %>% as.numeric()) %>%
  group_by(BMI, .trt, time_id) %>% summarise(.value = mean(.values))
#Uncomment next line to save the outcome
# qs::qsave(pdp_data, here::here("Data", "bmi_trt_int_pdp_preds.qs"))

Plot PDPs

#rug
bmi_rug <- data_prep %>%
  dplyr::pull(BMI)

pdp_data <- qs::qread(here::here("Data", "bmi_trt_int_pdp_preds.qs"))
#create range for BMI
bmi_range <- data_prep %>%
  dplyr::pull(BMI) %>%
  range()
#plot
p <- pdp_data %>%
  ggplot2::ggplot(aes(
    x = BMI,
    y = 1 - .value,
    color = .trt
  )) + # 1 minus for event probability, not survival prob
  geom_line()    +
  facet_wrap(
    ~ time_id,
    nrow = 2,
    labeller =  ggplot2::labeller(
      time_id  = function(s) {
        rf_model$unique.death.times[as.numeric(s)] %>% round(3)
      },
      # construct time labels within the function
      .default = ggplot2::label_value
    )
  ) +
  ylab("event probability")

p + ggplot2::geom_rug(
  data = bmi_rug %>% tibble::enframe(),
  mapping = ggplot2::aes(x = value),
  inherit.aes = F,
  sides = "b",
  alpha = 1,
  col = "#B3B3B3"
) +
  coord_cartesian(ylim = c(0, 0.5))

Session info

 R version 4.0.1 (2020-06-06)
 Platform: x86_64-pc-linux-gnu (64-bit)
 Running under: Ubuntu 18.04.6 LTS
 
 Matrix products: default
 BLAS:   /opt/R/4.0.1/lib/R/lib/libRblas.so
 LAPACK: /opt/R/4.0.1/lib/R/lib/libRlapack.so
 
 attached base packages:
 [1] stats     graphics  grDevices utils     datasets  methods   base     
 
 other attached packages:
  [1] ranger_0.13.1      mlr3tuning_0.13.0  paradox_0.9.0      mlr3proba_0.4.0   
  [5] mlr3learners_0.4.5 mlr3_0.13.3        textrecipes_0.4.1  yardstick_0.0.8   
  [9] workflowsets_0.0.2 workflows_0.2.3    tune_0.1.5         rsample_0.1.0     
 [13] recipes_0.1.16     parsnip_0.2.1.9001 modeldata_0.1.1    infer_0.5.4       
 [17] dials_0.0.9        scales_1.1.1       broom_0.7.9        tidymodels_0.1.3  
 [21] skimr_2.1.4        qs_0.24.1          tictoc_1.0.1       here_1.0.1        
 [25] kableExtra_1.3.4   forcats_0.5.1      stringr_1.4.0      dplyr_1.0.7       
 [29] purrr_0.3.4        readr_1.4.0        tidyr_1.1.4        tibble_3.1.5      
 [33] ggplot2_3.3.5      tidyverse_1.3.1   
 

Program: /home/antigoni_elefsinioti/sdi_ml_course/Demo/Demo_simulated_tte.Rmd


HTML template by Sebastian Voss of Chrestos GmbH & Co. KG on behalf of Biomarker & Data Insights, Bayer AG in 2019